package com.diven.spark.ml.learn.transforms

import java.io.FileOutputStream

import com.diven.spark.ml.learn.core.{BaseSpark, BaseTest}
import javax.xml.transform.stream.StreamResult
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.{Binarizer, ColumnCastType, ColumnFilter, Imputer, IndexToString, MaxAbsScaler, MinMaxScaler, QuantileDiscretizer, RFormula, StandardScaler, StringIndexer, VectorAssembler, OneHotEncoderEstimator => OneHotEncoder}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{DataTypes, MetadataBuilder, StructField}
import org.jpmml.model.JAXBUtil
import org.jpmml.sparkml.PMMLBuilder


object TransformPipeline extends BaseTest {

  /** Factory method: builds a fresh [[TransformPipeline]] task instance. */
  def apply(): BaseSpark = {
    new TransformPipeline()
  }

}

class TransformPipeline extends BaseSpark {

  /**
   * Run the demo transform pipeline: cast column types, impute missing
   * values, fit the pipeline on the input data, then print the transformed
   * schema and the number of retained fitted stages.
   *
   * Expects the input to carry at least the columns
   * height / weight / income / score / age / qualified.
   *
   * @param dataFrame input data to fit and inspect
   */
  override def execute(dataFrame: DataFrame): Unit = {
    // Cast raw columns to the numeric types the downstream stages expect.
    val columnCastType = new ColumnCastType().setCastColumns(Map(
      "height" -> DataTypes.DoubleType,
      "weight" -> DataTypes.DoubleType,
      "income" -> DataTypes.DoubleType,
      "score" -> DataTypes.DoubleType,
      "qualified" -> DataTypes.IntegerType))
    val dataset = dataFrame

    val columnFilter = new ColumnFilter().setFilterColumns(Array("height", "weight", "income", "score", "age", "qualified"))

    // Fill missing values using the column mean.
    val imputer = new Imputer()
      .setInputCols(Array("height", "weight", "income", "score"))   // columns to transform
      .setOutputCols(Array("height", "weight", "income", "score"))  // output names (overwrite in place)
      .setStrategy("mean")

    // Binarize: values above the threshold map to 1.0, others to 0.0.
    val binarizer = new Binarizer()
      .setInputCol("score")               // column to transform
      .setOutputCol("score_binarizer")    // output column name
      .setThreshold(55d)                  // threshold

    // Assemble feature columns into a single vector column.
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("height", "weight", "income", "score_binarizer"))
      .setOutputCol("features")
      .setHandleInvalid("skip")

    // Rescale each feature to the [min, max] range.
    val minMaxScaler = new MinMaxScaler()
      .setInputCol("features")              // column to transform
      .setOutputCol("features_minmax")      // output column name
      .setMin(0.0)                          // lower bound of the feature range
      .setMax(1.0)                          // upper bound of the feature range

    // Scale each feature into [-1, 1] by dividing by its maximum absolute value.
    val maxAbsScaler = new MaxAbsScaler()
      .setInputCol("features_minmax")       // column to transform
      .setOutputCol("features_minabs")      // output column name

    // Standardize features to zero mean and/or unit standard deviation (default: std only).
    val standardScaler = new StandardScaler()
      .setInputCol("features_minabs")       // column to transform
      .setOutputCol("features_standard")    // output column name
      .setWithMean(true)                    // center to zero mean
      .setWithStd(true)                     // scale to unit standard deviation

//    // One-hot encoding
//    val oneHotEncoder = new OneHotEncoder()
//      .setInputCols(Array("income", "height"))                 // columns to transform
//      .setOutputCols(Array("income_onehot", "height_onehot"))  // output column names
//      .setHandleInvalid("keep")
//      .setDropLast(true)
//
//    val oneHotEncoder1 = new OneHotEncoder()
//      .setInputCols(Array("income", "height"))                     // columns to transform
//      .setOutputCols(Array("income_onehot_1", "height_onehot_1"))  // output column names
//      .setHandleInvalid("keep")
//      .setDropLast(true)

//    val quantileDiscretizer = new QuantileDiscretizer()
//      .setInputCols(Array("income", "height"))         // columns to transform
//      .setOutputCols(Array("income_qd", "height_qd"))  // output column names
//      .setNumBucketsArray(Array(5, 10))
//      .setHandleInvalid("keep")
//      .setRelativeError(0.001)


    // Index the label column into numeric label indices.
    val labelIndexer = new StringIndexer()
      .setInputCol("qualified")
      .setOutputCol("label")
      .fit(dataset)

    // The classifier used to train the model.
    val dt = new DecisionTreeClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features_standard")

    // Map predicted indices back to the original string labels.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labels)

    // Pipeline (most stages are disabled while experimenting).
    val pipeline = new Pipeline().setStages(Array(
      columnCastType,
//      columnFilter,
//      binarizer,
      imputer
//      vectorAssembler
//      minMaxScaler,
//      maxAbsScaler,
//      standardScaler,
//      labelIndexer,
//      dt,
//      labelConverter
//      oneHotEncoder,
//      oneHotEncoder1,
//      quantileDiscretizer
    ))

//    pipeline.fit(dataFrame).transform(dataFrame).select(
//      "income", "income_qd",
//      "income_onehot", "income_onehot_1",
//      "height", "height_qd",
//      "height_onehot", "height_onehot_1"
//    ).show(100, 1000)
//    pipeline.fit(dataset).transform(dataset).show(100, 1000)

//    pipeline.fit(dataset).transform(dataset).show()

//    val pmml = new PMMLBuilder(dataset.schema, pipeline.fit(dataset)).build()
//    JAXBUtil.marshalPMML(pmml, new StreamResult(new FileOutputStream("./spark-pmml-model-003.pmml")))
//
//    val age = dataFrame.schema("age")
//    val meta = new MetadataBuilder().withMetadata(age.metadata).putDouble("xx", 111).putString("xxxx", "cccc").build()
//new StructField()

    val model = pipeline.fit(dataFrame)

    // Inspect the schema produced by the fitted pipeline (was a `var`, but never reassigned).
    val schema = model.transformSchema(dataset.schema)
//    var schema = columnFilter.transform(dataset).schema
    println("xxxxxxxxxxxxxxxxxxxxxxx")
    schema.foreach(field => {
      println(field.name, field.metadata.json)
    })
    println("xxxxxxxxxxxxxxxxxxxxxxx")

//    model.transform(dataset).show()


//    model.stages.foreach(x=> {
//      println(x.toString())
//      println(x.getClass)
//    })

    // Keep only the fitted stages that are neither ColumnCastType nor IndexToString.
    // BUG FIX: the original predicate used `||`, which is a De Morgan mistake —
    // no stage can be an instance of both types, so `!a || !b` was always true
    // and nothing was ever filtered out. `&&` excludes both stage types.
    val retainedStages = model.stages.filter(stage =>
      !stage.isInstanceOf[ColumnCastType] && !stage.isInstanceOf[IndexToString])

    println("xxxxxxxxxxxxxxxxxxxxxxx")
    // Refit using only the retained stages and report how many fitted stages
    // remain (the original computed the size but discarded the result).
    println(new Pipeline().setStages(retainedStages).fit(dataset).stages.length)
    println("xxxxxxxxxxxxxxxxxxxxxxx")
  }


}
